Customized RNN

Brief

Learning how to define custom operations inside RNN cells with the TensorFlow r1.3 API.


In [1]:
import tensorflow as tf
import numpy as np

Define MyRnnCell

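The following method and properties must be defined for a custom RNN cell:

  • __call__ (method): computes one time step and returns (output, new_state)
  • output_size (property): size of the output produced at each step
  • state_size (property): size of the recurrent state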

In [2]:
class MyRnnCell(tf.nn.rnn_cell.RNNCell):
    """A minimal recurrent cell with a tanh hidden update and a tanh output layer.

    One time step computes::

        new_state = tanh(input @ W_xh + state @ W_hh)
        output    = tanh(new_state @ W_ho + b_o)

    There is no bias on the hidden update, and the final tanh bounds every
    output component to (-1, 1). Output and state both have ``state_size``
    units.

    Args:
        state_size: Number of hidden units; also the output size.
        dtype: TensorFlow dtype used for all variables (e.g. ``tf.float64``).
        input_size: Depth of each input vector. Defaults to ``state_size``,
            which reproduces the previous behaviour where the input depth had
            to equal the state size.

    NOTE: the variables are created eagerly with ``tf.get_variable`` in the
    current variable scope under the fixed names "W_xh", "W_hh", "W_ho" and
    "b_o". Building a second cell in the same graph/scope will therefore fail
    unless variable reuse is enabled or the graph is reset first.
    """

    def __init__(self, state_size, dtype, input_size=None):
        # Initialise the RNNCell base class so its internal bookkeeping exists.
        super(MyRnnCell, self).__init__()
        self._state_size = state_size
        # W_xh used to be hard-coded to [state_size, state_size], which only
        # worked when the input depth happened to equal the state size.
        self._input_size = state_size if input_size is None else input_size
        self._dtype = dtype
        init = tf.truncated_normal_initializer()
        # Input -> hidden weights.
        self._W_xh = tf.get_variable(shape=[self._input_size, self._state_size],
                                     dtype=self._dtype, name="W_xh", initializer=init)
        # Hidden -> hidden (recurrent) weights.
        self._W_hh = tf.get_variable(shape=[self._state_size, self._state_size],
                                     dtype=self._dtype, name="W_hh", initializer=init)
        # Hidden -> output weights and output bias.
        self._W_ho = tf.get_variable(shape=[self._state_size, self._state_size],
                                     dtype=self._dtype, name="W_ho", initializer=init)
        self._b_o = tf.get_variable(shape=[self._state_size], dtype=self._dtype,
                                    name="b_o", initializer=init)

    def __call__(self, _input, state, scope=None):
        """Advance the cell by one time step.

        Args:
            _input: 2-D input tensor whose last dimension is ``input_size``
                (presumably [batch, input_size] as fed by tf.nn.dynamic_rnn).
            state: 2-D previous-state tensor whose last dimension is ``state_size``.
            scope: Accepted for RNNCell API compatibility; unused.

        Returns:
            A tuple ``(output, new_state)``; both have last dimension ``state_size``.
        """
        new_state = tf.tanh(tf.matmul(_input, self._W_xh) + tf.matmul(state, self._W_hh))
        new_output = tf.tanh(tf.matmul(new_state, self._W_ho) + self._b_o)
        return new_output, new_state

    @property
    def output_size(self):
        """Size of the per-step output (equal to the state size)."""
        return self._state_size

    @property
    def state_size(self):
        """Size of the recurrent hidden state."""
        return self._state_size

Create an instance of the RNN cell


In [3]:
# Start from an empty graph: MyRnnCell creates its variables under fixed names,
# so re-running this cell on the old graph would collide with them.
tf.reset_default_graph()
# Fix the graph-level seed *after* the reset (the seed belongs to the new
# default graph) and *before* the cell is built, so the random weight
# initialisation -- and therefore the training result -- is reproducible.
SEED = 42
tf.set_random_seed(SEED)
test_cell = MyRnnCell(2, tf.float64)

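Create a sample sequence

The one-hot pattern [1, 0], [0, 1], [0, 1] is repeated 30 times. The task is next-step prediction: the network is trained on the first five transitions and then tested on the whole sequence.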


In [4]:
# One period of the pattern: one-hot [1, 0] followed by [0, 1] twice.
period = np.array([[1, 0], [0, 1], [0, 1]], dtype=np.float64)
# Stack 30 copies of the period along the time axis -> 90 time steps.
sample_seq = np.tile(period, (30, 1))
print("Sample sequence:\n{}".format(sample_seq))
# Next-step prediction: every target row is the input row one step later.
# Train on a short prefix (5 transitions); test on the full sequence.
train_input, train_output = sample_seq[0:5, :], sample_seq[1:6, :]
test_input, test_output = sample_seq[:-1, :], sample_seq[1:, :]


Sample sequence:
[[ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]]

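Training & Testing

The cell is unrolled with tf.nn.dynamic_rnn over the training prefix and optimised with Adam. Afterwards it is run over the full sequence so that its activations can be compared with the expected next elements.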


In [5]:
# Learning rate and number of optimisation steps for this toy problem.
LEARNING_RATE = 1e-4
TRAIN_STEPS = 20000

# Sequences are fed as [time, depth] matrices (depth 2 = one-hot width).
inputs = tf.placeholder(shape=[None, 2], dtype=tf.float64)
targets = tf.placeholder(shape=[None, 2], dtype=tf.float64)
# dynamic_rnn expects [batch, time, depth]: treat the whole sequence as one batch.
batch_inputs = tf.reshape(inputs, shape=[1, -1, 2])
batch_outputs, final_state = tf.nn.dynamic_rnn(test_cell, batch_inputs, dtype=tf.float64)
# De-batch back to [time, depth] so each row lines up with its target row.
outputs = tf.reshape(batch_outputs, shape=[-1, 2])
# Per-step cross-entropy. NOTE: MyRnnCell's outputs come out of a tanh, so these
# "logits" lie in (-1, 1) and the softmax can never become fully confident.
loss = tf.nn.softmax_cross_entropy_with_logits(labels=targets, logits=outputs)
# Optimise an explicit scalar (mean over time steps) instead of relying on
# minimize() implicitly summing a vector-valued loss.
mean_loss = tf.reduce_mean(loss)
optimize_op = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(mean_loss)
print("Training network")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    train_feed = {inputs: train_input, targets: train_output}
    for _ in range(TRAIN_STEPS):
        # Only the update op is needed; fetching `outputs` every step was wasted work.
        sess.run(optimize_op, feed_dict=train_feed)
    print("Testing network with input:\n{}".format(test_input))
    print("Expected outputs:\n{}\nNetwork activations:\n{}".format(test_output,
                                                                   sess.run(outputs, feed_dict={inputs: test_input})))


Training network
Testing network with input:
[[ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]]
Expected outputs:
[[ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]
 [ 1.  0.]
 [ 0.  1.]
 [ 0.  1.]]
Network activations:
[[-0.97185739  0.9359846 ]
 [-0.99579763  0.98666112]
 [ 0.96687398 -0.87659884]
 [-0.98334563  0.95735965]
 [-0.99231504  0.97912792]
 [ 0.96500562 -0.87144929]
 [-0.98290786  0.95645268]
 [-0.9928629   0.98024123]
 [ 0.96526277 -0.87215392]
 [-0.98296675  0.95657436]
 [-0.99279476  0.98010156]
 [ 0.96523031 -0.8720649 ]
 [-0.98295929  0.95655893]
 [-0.99280349  0.98011943]
 [ 0.96523446 -0.87207628]
 [-0.98296024  0.95656091]
 [-0.99280237  0.98011715]
 [ 0.96523393 -0.87207483]
 [-0.98296012  0.95656065]
 [-0.99280251  0.98011744]
 [ 0.965234   -0.87207501]
 [-0.98296014  0.95656069]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]
 [ 0.96523399 -0.87207499]
 [-0.98296014  0.95656068]
 [-0.9928025   0.98011741]]

In [ ]: